
Fix the condition for calling update_learning_rates #7032

Merged
7 commits merged into Lightning-AI:master on May 17, 2021

Conversation

tkng
Contributor

@tkng tkng commented Apr 15, 2021

What does this PR do?

This pull request fixes the condition under which update_learning_rates is called in training_loop.py.

update_learning_rates in the OptimizerConnector class is called from several places. In training_loop.py, one of the calling conditions is missing a not.

As a result, when the Trainer's check_val_every_n_epoch is set to a value greater than 1, the learning rate is not updated as expected. For example, if check_val_every_n_epoch is set to 2, the learning rate will be updated only once every 2 epochs.

more detail

The current condition for calling update_learning_rates is as follows.

if (val_loop_called and not should_check_val) or should_train_only

The comment above the line says "update epoch level lr_schedulers if no val loop outside train loop is triggered", but the current code does not behave as the comment describes.

This pull request fixes this issue by inserting not before val_loop_called.
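The described fix can be sketched as follows. This is a standalone illustration, not the actual Trainer code; the variable names mirror those in training_loop.py:

```python
# Standalone sketch of the condition fix; the names mirror the variables
# in training_loop.py, but this is not the actual Trainer implementation.

def should_update_lr_schedulers(val_loop_called: bool,
                                should_check_val: bool,
                                should_train_only: bool) -> bool:
    # Buggy (current master): skips the update on epochs where the val loop
    # did NOT run, e.g. the "off" epochs when check_val_every_n_epoch > 1:
    #   return (val_loop_called and not should_check_val) or should_train_only
    # Fixed: update epoch-level schedulers when no val loop outside the
    # train loop is about to be triggered:
    return (not val_loop_called and not should_check_val) or should_train_only
```

With the fix, an epoch with no validation (both flags False) triggers the update, which the buggy condition skipped.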

Fixes #6616

By the way, we can assume that update_learning_rates has no side effects (in fact, I checked the current behavior of PyTorch Lightning by writing a custom learning rate rule myself; currently, it is called twice in one epoch). So I think we could simply call it every time in the training loop without any complicated conditions. However, since I don't understand the whole codebase of this project, I kept the changes in this pull request as small as possible.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

@tkng tkng marked this pull request as ready for review April 15, 2021 05:43
@codecov

codecov bot commented Apr 15, 2021

Codecov Report

Merging #7032 (6d8fff3) into master (7ca4173) will decrease coverage by 4%.
The diff coverage is 100%.

@@           Coverage Diff           @@
##           master   #7032    +/-   ##
=======================================
- Coverage      92%     88%    -4%     
=======================================
  Files         196     196            
  Lines       12835   12833     -2     
=======================================
- Hits        11832   11268   -564     
- Misses       1003    1565   +562     

@carmocca
Contributor

Can you write a test that fails on current master without the fix?

@tkng
Contributor Author

tkng commented Apr 16, 2021

I'm sorry, but I have no idea how to write a test for this issue...

Even if update_learning_rates is not called, training itself still proceeds. Also, I can't think of a good way to check from outside the Trainer that update_learning_rates is not being called.

@awaelchli
Contributor

@tkng You must have observed this bug somehow in your code right? If you provide a minimal example (see our bug_report_model), we can easily help you convert it to a good regression test :)

@tkng
Contributor Author

tkng commented Apr 17, 2021

@awaelchli thanks, I wrote a minimal example and uploaded it to https://gist.github.com/tkng/a89ffdab2558aedc04a67f7142e6aeb0. (Sorry, I noticed bug_report_model.py only after I wrote this example, so the code is based on lightning-flash.)

Once you run this example and open the log directory in TensorBoard, you'll see the learning rate decay shown in the attached screenshot. The pink curve is with current pytorch-lightning; the green one is with the patch attached to this PR.

[screenshot: learning rate curves in TensorBoard]

@awaelchli awaelchli added bug Something isn't working priority: 1 Medium priority task labels Apr 17, 2021
@awaelchli awaelchli added this to the 1.3 milestone Apr 17, 2021
@awaelchli
Contributor

Thanks. Based on your description and code, I pushed a minimal test that fails on master.

@awaelchli awaelchli self-assigned this Apr 17, 2021
@ananthsub
Contributor

Why is the learning rate update dependent on the validation loop at all?

@awaelchli
Contributor

I suspect because of this scheduler
https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
But I'm not an expert on this section of the code. I believe @rohitgr7 has worked a bit here if I'm not mistaken.

Member

@Borda Borda left a comment


lgtm

@Borda Borda added the ready PRs ready to be merged label Apr 18, 2021
@Borda Borda enabled auto-merge (squash) April 18, 2021 09:48
@rohitgr7 rohitgr7 disabled auto-merge April 18, 2021 11:39
@rohitgr7
Contributor

Well, the solution here is a bit incorrect. val_loop_called actually tells whether the val loop was called within the training loop, for the cases where val_check_interval (float) < 1.0 or val_check_interval < num_train_batches. In such a case val_loop_called will be True, and we need to make sure the epoch-level schedulers are explicitly called outside the train loop if no epoch-end validation is going to happen. I have updated a test in #7084 which will fail with your changes.

Epoch-level schedulers are also updated within the evaluation loop, because we need to ensure they get updated after the validation loop and right before checkpointing, so that the correct scheduler state dict is saved in the checkpoint file. I know this is a bit complicated, but most of the complexity is due to the ReduceLROnPlateau scheduler, which needs a monitor to update itself. This monitor can be anything logged in either train or val.

To handle the cases where check_val_every_n_epoch > 1, I think we need to somehow pull the scheduler update out of the val loop, while still ensuring that the correct scheduler state_dict gets saved in the checkpoint file (we have tests for that), and of course it should work with ReduceLROnPlateau (I think we have tests for that too).
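The val_check_interval case mentioned above can be illustrated with a rough approximation of the batching logic (not the actual Trainer implementation):

```python
# Rough illustration (not Trainer internals): with a float
# val_check_interval < 1.0, validation runs one or more times *inside*
# the training epoch, so val_loop_called is already True at epoch end
# even when no epoch-end validation is triggered.

def mid_epoch_val_batches(num_train_batches: int, val_check_interval: float):
    """Return the 1-based train-batch indices after which validation runs."""
    val_check_batch = max(1, int(num_train_batches * val_check_interval))
    return [b for b in range(1, num_train_batches + 1)
            if b % val_check_batch == 0]
```

For example, with 10 training batches and val_check_interval=0.5, validation runs after batches 5 and 10, so the first val loop happens mid-epoch.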

@rohitgr7 rohitgr7 removed the ready PRs ready to be merged label Apr 18, 2021
@tkng
Contributor Author

tkng commented Apr 20, 2021

@rohitgr7 Hmm, the problem seems more complicated than I had thought...

May I assume that update_learning_rates should always be called at the end of one training epoch, even for ReduceLROnPlateau?

If that assumption is correct, then I think the right approach would be to call update_learning_rates immediately after the end of the training process for one epoch, in the run_training_epoch method of the TrainLoop class.

I believe this understanding is consistent with your statement, "I think we need to somehow pull out the scheduler update from val_loop and still need to ensure that correct scheduler state_dict gets updated within the checkpoint file."

I tried to write such code, but I could not make the unit tests pass. On the other hand, if I discarded the above understanding and just aimed to pass the unit tests, I only needed to change one line, as I just pushed. (Just to be sure, this also passes the test from #7084.) However, I am not sure whether my current code is correct. 🤔

@rohitgr7
Contributor

Yeah, val_loop_called is redundant, since should_check_val can cover that case. Calling the lr update right after the train loop would create a problem with ReduceLROnPlateau, since the monitor might not be logged by then. And even if we did it after the validation loop, the checkpoint wouldn't get the correct lr state dict. I think I have a solution for this; will try it on the weekend.

@carmocca
Contributor

Converting to draft so merging is blocked until these comments are addressed 👍

@carmocca carmocca marked this pull request as draft April 21, 2021 21:06
@edenlightning edenlightning removed this from the v1.3 milestone May 4, 2021
@Borda
Member

Borda commented May 11, 2021

@tkng how is it going here, still WIP or ready to go? 🐰

@tkng
Contributor Author

tkng commented May 13, 2021

@Borda This pull request passes all the current unit tests, but according to @rohitgr7, it is incompatible with ReduceLROnPlateau.

I have tried to rewrite this PR to be compatible with ReduceLROnPlateau, but I failed to write code that passes all tests and still works with ReduceLROnPlateau. I think we have two options:

  1. wait for new code from @rohitgr7
  2. merge this pull request, since it passes all the current unit tests

I am OK with either, but we'll need an opinion from @rohitgr7.

@rohitgr7 rohitgr7 force-pushed the bugfix/update_learning_rates branch from 0a1cd5e to 76cb58d Compare May 13, 2021 19:09
@rohitgr7
Contributor

rohitgr7 commented May 13, 2021

I tried something, but couldn't find an optimal solution that covers all the cases, because of ReduceLROnPlateau. The problem is: if a monitor is used with this scheduler and check_val_every_n_epoch > 1, there is no way to decide whether to do a scheduler.step after validation or after every train loop, because the monitor can be anything and can be logged anywhere. For example, with check_val_every_n_epoch > 1, if the monitor is logged in the val loop then we ideally shouldn't do scheduler.step after every epoch, but if it's logged in the train loop then we should. Still, one can easily configure this correctly by setting scheduler['frequency'] accordingly, if they understand how they have configured their train/val loops via the trainer parameters.

So I'd say merge this one.

max_epochs=epochs,
)
trainer.fit(model)
assert mocked_sched.call_count == expected_steps
Contributor


Noob question: why do we want to update schedulers in all epochs even if we only run validation in some of them?

Contributor


Running validation only on some epochs is just a choice. Ideally, a scheduler should update after every epoch if scheduler['frequency'] = 1 (the default) and scheduler['interval'] == 'epoch'. If someone wants to align it with validation, they can set scheduler['frequency'] = check_val_every_n_epoch.
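The frequency alignment can be shown with a small plain-Python sketch (not Lightning internals; it assumes scheduler['interval'] == 'epoch'):

```python
# Plain-Python illustration (not Lightning internals) of how
# scheduler['frequency'] controls which epochs an epoch-level scheduler
# steps on; assumes scheduler['interval'] == 'epoch'.

def scheduler_step_epochs(max_epochs: int, frequency: int):
    """Return the 1-based epochs on which the scheduler steps."""
    return [e for e in range(1, max_epochs + 1) if e % frequency == 0]

# frequency=1 (default): the scheduler steps every epoch.
# frequency=2 (matching check_val_every_n_epoch=2): steps only on
# the epochs where validation runs.
```

So setting scheduler['frequency'] = check_val_every_n_epoch makes the scheduler step only on validation epochs, which is what a monitor logged in the val loop requires.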

@carmocca carmocca merged commit 20f6337 into Lightning-AI:master May 17, 2021
@carmocca carmocca mentioned this pull request May 17, 2021
edgarriba pushed a commit that referenced this pull request May 18, 2021
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Borda pushed a commit that referenced this pull request May 18, 2021
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
lexierule pushed a commit that referenced this pull request May 19, 2021
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@carmocca carmocca added this to the v1.3.x milestone May 20, 2021
Labels
bug Something isn't working priority: 1 Medium priority task
Development

Successfully merging this pull request may close these issues.

Epoch level scheduler is not well called except check_val_every_n_epoch == 1
7 participants